import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Compare per-round loss curves of FML (fed_mutual) and FedAvg runs loaded
# from saved .npy arrays, and write the figure to ./1.pdf.
#
# NOTE(review): `acc` and `acc1` are loaded from the same file, so the first
# "FML" and "FedAvg" curves are identical. One of these paths is probably
# meant to point at a different run; confirm.
acc = np.load('./save/1/train_loss_4.npy', allow_pickle=True)
acc1 = np.load('./save/1/train_loss_4.npy', allow_pickle=True)

acc2 = np.load('./save/27/fed_avg/train_loss_4.npy', allow_pickle=True)
acc3 = np.load('./save/30/fed_mutual/train_loss_4.npy', allow_pickle=True)

# Print the last recorded value of each of the first two series.
print(acc[-1])
print(acc1[-1])

plt.figure()
# NOTE(review): the loaded files are named train_loss_*, but the title and
# y-label describe test-set loss. Confirm which one is intended.
plt.title('Loss of Global Model over Testset ')
# Each curve gets an x-axis matching its own length. Reusing len(acc) for
# every series made matplotlib raise ValueError whenever a run had a
# different number of rounds than `acc`.
plt.plot(range(len(acc1)), acc1, 'r')
plt.plot(range(len(acc)), acc, 'g', linestyle=':')

plt.plot(range(len(acc3)), acc3, 'r')
plt.plot(range(len(acc2)), acc2, 'g', linestyle=':')
plt.ylabel('Test loss')
# Two labels for four lines: they name the first two curves. The second pair
# reuses the same colour/linestyle, so the legend also reads correctly for it.
plt.legend(['FML', "FedAvg"])
plt.tight_layout()
plt.savefig('./1.pdf')
print("绘图完毕")  # "Plotting finished"





#
#
# # ###################
# acc = np.load('./save/105/fed_avg/val_meme_acc_2.npy', allow_pickle=True)
# acc1 = np.load('./save/105/fed_mutual/val_acc_2.npy', allow_pickle=True)
#
# acc2 = np.load('./save/107/fed_avg/acc_R1000_Dcifar10_Afed_avg_GCNN_LCNN_I0.npy', allow_pickle=True)
# acc3 = np.load('./save/107/fed_mutual/acc_R1000_Dcifar10_Afed_mutual_GCNN_LCNN_I0.npy', allow_pickle=True)
#
# acc4 = np.load('./save/108/fed_avg/acc_R1000_Dmnist_Afed_avg_Gmlp_Lmlp_I0.npy', allow_pickle=True)
# acc5 = np.load('./save/108/fed_mutual/acc_R1000_Dmnist_Afed_mutual_Gmlp_Lmlp_I0.npy', allow_pickle=True)
#
# plt.figure()
# # plt.subplot(1, 2, 1)
# plt.title('Communication Efficiency ')
# plt.plot(range(len(acc1)), acc1, 'g')
# plt.plot(range(len(acc)), acc, 'r',linestyle=':')
# #
# # plt.plot(range(len(acc3)), acc3, 'r')
# # plt.plot(range(len(acc2)), acc2, 'g',linestyle=':')
# #
# # plt.plot(range(len(acc5)), acc5, 'k')
# # plt.plot(range(len(acc4)), acc4, 'g',linestyle=':')
#
# plt.ylabel('Global accuracy')
# plt.xlabel('Communication Rounds')
# # plt.legend(['fed_avg','fed_mutual'])
# # 'FedAvg,lenet5,cifar10','FML,cnn1,cifar10','FedAvg,cnn1,cifar10',
# plt.legend(['FML,lenet5,cifar10','FML,mlp,mnist','FedAvg,mlp,mnist'])
# plt.tight_layout()
# plt.savefig('./save/108/1.png')
# print("绘图完毕")